{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import dpp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from copy import deepcopy\n",
    "# Allocate new tensors on the GPU by default when CUDA is available; otherwise keep\n",
    "# PyTorch's CPU default so that this cell does not fail on CPU-only machines.\n",
    "if torch.cuda.is_available():\n",
    "    torch.set_default_tensor_type(torch.cuda.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['lastfm',\n",
       " 'stack_overflow',\n",
       " 'synth/hawkes1',\n",
       " 'synth/hawkes2',\n",
       " 'synth/nonstationary_poisson',\n",
       " 'synth/nonstationary_renewal',\n",
       " 'synth/self_correcting',\n",
       " 'synth/stationary_poisson',\n",
       " 'synth/stationary_renewal']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the dataset names that are available for loading\n",
    "dpp.data.list_datasets()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: NumPy and PyTorch are seeded with the same value\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Dataset to train on (dpp.data.list_datasets() returns the available names)\n",
    "dataset_name = 'synth/hawkes1'\n",
    "\n",
    "# Model hyperparameters\n",
    "context_size = 64         # dimensionality of the RNN hidden vector\n",
    "mark_embedding_size = 32  # dimensionality of the mark embedding fed to the RNN\n",
    "num_mix_components = 64   # how many components the mixture model has\n",
    "rnn_type = \"GRU\"          # encoder RNN, one of {\"RNN\", \"GRU\", \"LSTM\"}\n",
    "\n",
    "# Training hyperparameters\n",
    "batch_size = 64        # sequences per batch\n",
    "regularization = 1e-5  # L2 penalty (passed to Adam as weight_decay)\n",
    "learning_rate = 1e-3   # step size of the Adam optimizer\n",
    "max_epochs = 1000      # upper bound on the number of training epochs\n",
    "display_step = 5       # print training statistics every display_step epochs\n",
    "patience = 50          # stop after this many consecutive epochs without validation improvement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the chosen dataset and split it into train / validation / test parts\n",
    "dataset = dpp.data.load_dataset(dataset_name)\n",
    "d_train, d_val, d_test = dataset.train_val_test_split(seed=seed)\n",
    "\n",
    "# Only the training loader shuffles; validation and test loaders do not\n",
    "dl_train = d_train.get_dataloader(batch_size=batch_size, shuffle=True)\n",
    "dl_val, dl_test = (\n",
    "    split.get_dataloader(batch_size=batch_size, shuffle=False)\n",
    "    for split in (d_val, d_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building model...\n"
     ]
    }
   ],
   "source": [
    "# Define the model\n",
    "print('Building model...')\n",
    "# Inter-event time statistics of the training split, handed to the model below\n",
    "mean_log_inter_time, std_log_inter_time = d_train.get_inter_time_statistics()\n",
    "\n",
    "model_kwargs = dict(\n",
    "    num_marks=d_train.num_marks,\n",
    "    mean_log_inter_time=mean_log_inter_time,\n",
    "    std_log_inter_time=std_log_inter_time,\n",
    "    context_size=context_size,\n",
    "    mark_embedding_size=mark_embedding_size,\n",
    "    rnn_type=rnn_type,\n",
    "    num_mix_components=num_mix_components,\n",
    ")\n",
    "model = dpp.models.LogNormMix(**model_kwargs)\n",
    "\n",
    "# Adam with L2 regularization applied through weight_decay\n",
    "opt = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=learning_rate,\n",
    "    weight_decay=regularization,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_loss_over_dataloader(dl):\n",
    "    \"\"\"Return the negative log-likelihood summed over `dl`, divided by the summed `batch.size`.\n",
    "\n",
    "    Uses the notebook-level `model`; gradient tracking is disabled while iterating.\n",
    "    \"\"\"\n",
    "    nll_sum = 0.0\n",
    "    n_total = 0\n",
    "    with torch.no_grad():\n",
    "        for minibatch in dl:\n",
    "            log_probs = model.log_prob(minibatch)\n",
    "            nll_sum -= log_probs.sum().item()\n",
    "            n_total += minibatch.size\n",
    "    return nll_sum / n_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training...\n",
      "Epoch    0: loss_train_last_batch = 68.7, loss_val = 61.5\n",
      "Epoch    5: loss_train_last_batch = 49.6, loss_val = 52.4\n",
      "Epoch   10: loss_train_last_batch = 50.4, loss_val = 51.4\n",
      "Epoch   15: loss_train_last_batch = 56.8, loss_val = 50.8\n",
      "Epoch   20: loss_train_last_batch = 49.4, loss_val = 50.4\n",
      "Epoch   25: loss_train_last_batch = 54.0, loss_val = 49.9\n",
      "Epoch   30: loss_train_last_batch = 54.8, loss_val = 49.7\n",
      "Epoch   35: loss_train_last_batch = 56.2, loss_val = 49.6\n",
      "Epoch   40: loss_train_last_batch = 50.0, loss_val = 49.5\n",
      "Epoch   45: loss_train_last_batch = 48.2, loss_val = 49.6\n",
      "Epoch   50: loss_train_last_batch = 37.0, loss_val = 49.6\n",
      "Epoch   55: loss_train_last_batch = 57.6, loss_val = 49.5\n",
      "Epoch   60: loss_train_last_batch = 58.9, loss_val = 49.5\n",
      "Epoch   65: loss_train_last_batch = 52.6, loss_val = 49.5\n",
      "Epoch   70: loss_train_last_batch = 45.8, loss_val = 49.5\n",
      "Epoch   75: loss_train_last_batch = 55.3, loss_val = 49.5\n",
      "Epoch   80: loss_train_last_batch = 28.6, loss_val = 49.6\n",
      "Epoch   85: loss_train_last_batch = 37.0, loss_val = 49.5\n",
      "Epoch   90: loss_train_last_batch = 57.2, loss_val = 49.5\n",
      "Epoch   95: loss_train_last_batch = 60.9, loss_val = 49.6\n",
      "Epoch  100: loss_train_last_batch = 52.8, loss_val = 49.6\n",
      "Epoch  105: loss_train_last_batch = 40.0, loss_val = 49.7\n",
      "Epoch  110: loss_train_last_batch = 53.0, loss_val = 49.8\n",
      "Breaking due to early stopping at epoch 113\n"
     ]
    }
   ],
   "source": [
    "# Training\n",
    "print('Starting training...')\n",
    "\n",
    "min_delta = 1e-4  # Smallest decrease of the validation loss that resets the patience counter\n",
    "impatient = 0\n",
    "best_loss = np.inf\n",
    "best_model = deepcopy(model.state_dict())\n",
    "training_val_losses = []\n",
    "\n",
    "for epoch in range(max_epochs):\n",
    "    model.train()\n",
    "    for batch in dl_train:\n",
    "        opt.zero_grad()\n",
    "        loss = -model.log_prob(batch).mean()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        loss_val = aggregate_loss_over_dataloader(dl_val)\n",
    "        training_val_losses.append(loss_val)\n",
    "\n",
    "    # Any comparison with NaN is False, so a NaN validation loss lands in the else branch\n",
    "    # and counts as an epoch without improvement instead of replacing the best checkpoint.\n",
    "    if loss_val < best_loss:\n",
    "        improvement = best_loss - loss_val\n",
    "        best_loss = loss_val\n",
    "        best_model = deepcopy(model.state_dict())\n",
    "        # Improvements below min_delta still update the checkpoint but use up patience\n",
    "        impatient = 0 if improvement >= min_delta else impatient + 1\n",
    "    else:\n",
    "        impatient += 1\n",
    "\n",
    "    if impatient >= patience:\n",
    "        print(f'Breaking due to early stopping at epoch {epoch}')\n",
    "        break\n",
    "\n",
    "    # The reported training loss is that of the last batch of the epoch only\n",
    "    if epoch % display_step == 0:\n",
    "        print(f\"Epoch {epoch:4d}: loss_train_last_batch = {loss.item():.1f}, loss_val = {loss_val:.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation: restore the checkpoint stored in best_model during training\n",
    "model.load_state_dict(best_model)\n",
    "model.eval()\n",
    "\n",
    "# Loss of the restored model on each split, accumulated batch by batch\n",
    "with torch.no_grad():\n",
    "    final_loss_train, final_loss_val, final_loss_test = (\n",
    "        aggregate_loss_over_dataloader(loader)\n",
    "        for loader in (dl_train, dl_val, dl_test)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative log-likelihood:\n",
      " - Train: 48.6\n",
      " - Val:   49.5\n",
      " - Test:  43.7\n"
     ]
    }
   ],
   "source": [
    "# Print the loss of every split as one multi-line report\n",
    "report_lines = [\n",
    "    'Negative log-likelihood:',\n",
    "    f' - Train: {final_loss_train:.1f}',\n",
    "    f' - Val:   {final_loss_val:.1f}',\n",
    "    f' - Test:  {final_loss_test:.1f}',\n",
    "]\n",
    "print('\\n'.join(report_lines))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the distribution of sequence lengths for real and simulated data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample a batch of size 1000 from the model (t_max=100 presumably bounds the sampled time horizon -- TODO confirm)\n",
    "sampled_batch = model.sample(t_max=100, batch_size=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather every item of the full (unsplit) dataset into a single Batch\n",
    "real_batch = dpp.data.Batch.from_list(list(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f86529c2fd0>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgwElEQVR4nO3df5wV9X3v8debFUWrCSqkpYBdTFYQBfmlYEUl9ccFtZKmGqHxxlgbRMV7kz76uCEmNzG5xgdJaHolEhEjBb1GEtMkRYP1V6I2KTQsCAgqioqySiJSJaWoiH7uHzNLDodzzp6Bnd1zdt/Px+M89szM9zvn82XY/Zzvd2a+o4jAzMysWj06OwAzM6svThxmZpaJE4eZmWXixGFmZpk4cZiZWSYHdXYAHaFPnz7R2NjY2WGYmdWVlStXvh4RfYvXd4vE0djYSHNzc2eHYWZWVyS9VGq9h6rMzCwTJw4zM8vEicPMzDLJ9RyHpInATUAD8L2ImFW0Xen284CdwKcjYlW6bQFwAfBaRJxYUOdbwJ8Du4Dngcsj4s0822Fmtendd9+lpaWFt99+u7NDqWu9evViwIAB9OzZs6ryuSUOSQ3AXOAcoAVYIWlJRDxVUGwS0JS+xgK3pD8BFgI3A3cU7foh4AsRsVvSN4AvAJ/Pqx1mVrtaWlo44ogjaGxsJPkeallFBNu2baOlpYVBgwZVVSfPoapTgI0R8UJE7AIWA5OLykwG7ojEcqC3pH4AEfE48B/FO42IByNid7q4HBiQWwvMrKa9/fbbHH300U4aB0ASRx99dKZeW56Joz+wuWC5JV2XtUwlfw3cX2qDpGmSmiU1b926NcMuzayeOGkcuKz/hnkmjlKRFM/hXk2Z0juXvgjsBu4qtT0i5kfEmIgY07fvPvevmJnZfsrz5HgLMLBgeQDw6n6U2Yeky0hOnJ8VfqCImaUaZ/6sXfe3adb5bZb5+te/zve//30aGhro0aMHt956K2PHjm2z3v6YMGECs2fPZsyYMVWVf/TRR5k9ezb33Xdfu8aRZ+JYATRJGgS8AkwB/qqozBJghqTFJCfFt0fElko7Ta/U+jxwZkTsbP+wLW/lfrmr+SU1qyXLli3jvvvuY9WqVRxyyCG8/vrr7Nq1q7PDyl1uQ1XpCewZwAPA08API2K9pOmSpqfFlgIvABuB24CrW+tLuhtYBgyW1CLpinTTzcARwEOSVkual1cbzMwq2bJlC3369OGQQw4BoE+fPvzxH/8xX/va1zj55JM58cQTmTZtGq0DIxMmTOBzn/scZ5xxBscffzwrVqzg4x//OE1NTXzpS18CYNOmTQwZMoTLLruM4cOHc9FFF7Fz577fkR988EFOPfVURo0axcUXX8yOHTsA+Jd/+ReGDBnC+PHj+fGPf5xLu3O9ATAilkbEcRHx4Yj4erpuXkTMS99HRFyTbh8WEc0FdadGRL+I6BkRAyLi9nT9RyJiYESMSF/TS3+6mVm+zj33XDZv3sxxxx3H1VdfzWOPPQbAjBkzWLFiBevWreOtt97aa6jo4IMP5vHHH2f69OlMnjyZuXPnsm7dOhYuXMi2bdsA2LBhA9OmTWPt2rV84AMf4Lvf/e5en/v6669zww038PDDD7Nq1SrGjBnDt7/9bd5++20+85nPcO+99/Kv//qv/OY3v8ml3b5z3MxsPx1++OGsXLmS+fPn07dvXy655BIWLlzIL37xC8aOHcuwYcP4+c9/zvr16/fUufDCCwEYNmwYJ5xwAv369eOQQw7h2GOPZfPm5CLTgQMHctpppwFw6aWX8stf/nKvz12+fDlPPfUUp512GiNGjGDRokW89NJLPPPMMwwaNIimpiYkcemll+bS7m4xO66ZWV4aGhqYMGECEyZMYNiwYdx6662sXbuW5uZmBg4cyPXXX7/XPRKtw1o9evTY8751effu5Ba14stji5cjgnPOOYe77757r/WrV6/ukMuT3eMw
M9tPGzZs4LnnntuzvHr1agYPHgwk5zt27NjBj370o8z7ffnll1m2bBkAd999N+PHj99r+7hx4/jVr37Fxo0bAdi5cyfPPvssQ4YM4cUXX+T555/fUzcP7nGYWZfR0Vfm7dixg2uvvZY333yTgw46iI985CPMnz+f3r17M2zYMBobGzn55JMz7/f4449n0aJFXHnllTQ1NXHVVVfttb1v374sXLiQqVOn8s477wBwww03cNxxxzF//nzOP/98+vTpw/jx41m3bl27tLWQusNtEGPGjAk/yKl2+HJcay9PP/00xx9/fGeH0a42bdrEBRdckMsf/EpK/VtKWhkR+9w04qEqMzPLxENVVr0NJacFK2/wpHziMOvCGhsbO7y3kZV7HGZmlokTh5mZZeLEYWZmmThxmJlZJj45bmZdR9YLONpSxQUeDQ0NDBs2jN27dzNo0CDuvPNOevfunfmjFi5cSHNzMzfffHPFco2NjTQ3N9OnT5+yZW688Uauu+66zDFUyz0OM7MDcOihh7J69WrWrVvHUUcdxdy5czs7JG688cZc9+/EYWbWTk499VReeeUVAJ5//nkmTpzI6NGjOf3003nmmWcAuPfeexk7diwjR47k7LPP5re//W3FfW7bto1zzz2XkSNHcuWVV1J40/bHPvYxRo8ezQknnMD8+fMBmDlzJm+99RYjRozgk5/8ZNlyB8KJw8ysHbz33ns88sgje2a/nTZtGt/5zndYuXIls2fP5uqrk8cNjR8/nuXLl/PEE08wZcoUvvnNb1bc71e/+lXGjx/PE088wYUXXsjLL7+8Z9uCBQtYuXIlzc3NzJkzh23btjFr1qw9vaC77rqrbLkD4XMcZmYHoPXb/aZNmxg9ejTnnHMOO3bs4N/+7d+4+OKL95RrnVOqpaWFSy65hC1btrBr1y4GDRpUcf+PP/74ngcynX/++Rx55JF7ts2ZM4ef/OQnAGzevJnnnnuOo48+ep99VFuuWu5xmJkdgNZv9y+99BK7du1i7ty5vP/++/Tu3ZvVq1fveT399NMAXHvttcyYMYMnn3ySW2+9da8p18spNVX6o48+ysMPP8yyZctYs2YNI0eOLLmvastl4cRhZtYOPvjBDzJnzhxmz57NoYceyqBBg7jnnnuA5PkZa9asAWD79u30798fgEWLFrW53zPOOGPPkNP999/PG2+8sWc/Rx55JIcddhjPPPMMy5cv31OnZ8+evPvuu22W218eqjKzrqOT50cbOXIkJ510EosXL+auu+7iqquu4oYbbuDdd99lypQpnHTSSVx//fVcfPHF9O/fn3HjxvHiiy9W3OdXvvIVpk6dyqhRozjzzDM55phjAJg4cSLz5s1j+PDhDB48mHHjxu2pM23aNIYPH86oUaNYsGBB2XL7y9OqW/XaaZJDT6tu7aUrTqveWbJMq+4eh+WmXIIws/rmcxxmZpaJE4eZ1bXuMNyet6z/hk4cZla3evXqxbZt25w8DkBEsG3bNnr16lV1HZ/jMLO6NWDAAFpaWti6dWtnh1LXevXqxYABA6ou78RhZnWrZ8+ebd55be0v18QhaSJwE9AAfC8iZhVtV7r9PGAn8OmIWJVuWwBcALwWEScW1DkK+AHQCGwCPhERb+TZDmtfZ/VYWWaLL8c1qwe5neOQ1ADMBSYBQ4GpkoYWFZsENKWvacAtBdsWAhNL7Hom8EhENAGPpMtmZtZB8jw5fgqwMSJeiIhdwGJgclGZycAdkVgO9JbUDyAiHgf+o8R+JwOt9+kvAj6WR/BmZlZankNV/YHNBcstwNgqyvQHtlTY7x9GxBaAiNgi6UOlCkmaRtKL2XOLvuXjikUrymwZ3aFxmFnHyLPHse90jlB8zVw1ZfZLRMyPiDERMaZv377tsUszMyPfxNECDCxYHgC8uh9liv22dTgr/fnaAcZpZmYZ5Jk4VgBNkgZJOhiYAiwpKrME+JQS44DtrcNQFSwBLkvfXwb8c3sGbWZmleWWOCJi
NzADeAB4GvhhRKyXNF3S9LTYUuAFYCNwG3B1a31JdwPLgMGSWiRdkW6aBZwj6TngnHTZzMw6SK73cUTEUpLkULhuXsH7AK4pU3dqmfXbgLPaMUwzM8vAc1WZmVkmThxmZpaJE4eZmWXixGFmZpk4cZiZWSZOHGZmlokTh5mZZeLEYWZmmThxmJlZJn50rO1rw/2dHYGZ1TD3OMzMLBP3OKxq5R/YZGbdiROH5easHis7OwQzy4GHqszMLBMnDjMzy8RDVVY7Kl3NNXhSx8VhZhW5x2FmZpm4x2F1oXHmz0qu33R5me8+7qGY5cY9DjMzy8SJw8zMMvFQVXdQ7qRzmeEc3+hnZpW4x2FmZpk4cZiZWSYequoGyg093X5j173yqOxVWLPO7+BIzLoe9zjMzCwTJw4zM8sk16EqSROBm4AG4HsRMatou9Lt5wE7gU9HxKpKdSWNAOYBvYDdwNUR8es821E3/AAmM+sAufU4JDUAc4FJwFBgqqShRcUmAU3paxpwSxV1vwl8NSJGAF9Ol83MrIPkOVR1CrAxIl6IiF3AYmByUZnJwB2RWA70ltSvjboBfCB9/0Hg1RzbYGZmRfIcquoPbC5YbgHGVlGmfxt1Pws8IGk2SeL701IfLmkaSS+GY445Zr8aYB2r8o2HozPWKV3ezA5cnj0OlVgXVZapVPcq4HMRMRD4HHB7qQ+PiPkRMSYixvTt27fKkM3MrC15Jo4WYGDB8gD2HVYqV6ZS3cuAH6fv7yEZ1jIzsw6SZ+JYATRJGiTpYGAKsKSozBLgU0qMA7ZHxJY26r4KnJm+/zPguRzbYGZmRXI7xxERuyXNAB4guaR2QUSslzQ93T4PWEpyKe5GkstxL69UN931Z4CbJB0EvE16HsPMzDpGrvdxRMRSkuRQuG5ewfsArqm2brr+l/jMp5lZp/FcVd1YufmczvJ8AmZWgf9EmJlZJlUlDkkn5h2ImZnVh2qHqualVzctBL4fEW/mFpF1mLN6rOzsEMysDlXV44iI8cAnSe6taJb0fUnn5BqZmZnVpKrPcUTEc8CXgM+T3EcxR9Izkj6eV3BmZlZ7qj3HMVzSPwBPk9x09+cRcXz6/h9yjM/MzGpMtec4bgZuA66LiLdaV0bEq5K+lEtkZmZWk6pNHOcBb0XEewCSegC9ImJnRNyZW3RmZlZzqk0cDwNnAzvS5cOABykzpblZZyt/xdj5HRqHWVdU7cnxXhHRmjRI3x+WT0hmZlbLqk0c/yVpVOuCpNHAWxXKm5lZF1XtUNVngXsktT4Tox9wSS4RmZlZTasqcUTECklDgMEkT+d7JiLezTUyMzOrSVlmxz0ZaEzrjJRERNyRS1RmRTw9ilntqCpxSLoT+DCwGngvXR2AE4eZWTdTbY9jDDA0ffCSmZl1Y9VeVbUO+KM8AzEzs/pQbY+jD/CUpF8D77SujIgLc4nKzMxqVrWJ4/o8gzAzs/pR7eW4j0n6E6ApIh6WdBjQkG9oltUVi1Z0dghm1g1UO636Z4AfAbemq/oDP80pJjMzq2HVnhy/BjgN+B3seajTh/IKyszMale1ieOdiNjVuiDpIJL7OMzMrJupNnE8Juk64ND0WeP3APfmF5aZmdWqahPHTGAr8CRwJbCU5PnjZmbWzVSVOCLi/Yi4LSIujoiL0vdtDlVJmihpg6SNkmaW2C5Jc9Lta4umbi9bV9K16bb1kr5ZTRvMzKx9VDtX1YuUOKcREcdWqNMAzAXOAVqAFZKWRMRTBcUmAU3payxwCzC2Ul1JHwUmA8Mj4h1JPklvZtaBssxV1aoXcDFwVBt1TgE2RsQLAJIWk/zBL0wck4E70t7Lckm9JfUjmYW3XN2rgFkR8Q5ARLxWZRvMzKwdVDtUta3g9UpE/F/gz9qo1h/YXLDckq6rpkyluscBp0v6d0mPSTq51IdLmiapWVLz1q1b2wjVzMyqVe1Q1aiCxR4kPZAj2qpWYl3xcFe5MpXqHgQcCYwjeUbIDyUd
W3zOJSLmA/MBxowZ40uHzczaSbVDVX9f8H43sAn4RBt1WoCBBcsDgFerLHNwhbotwI/TRPFrSe+TTMLoboWZWQeodq6qj+7HvlcATZIGAa8AU4C/KiqzBJiRnsMYC2yPiC2Stlao+1OSYbJHJR1HkmRe34/46teG+zs7gm6vcebPSq7fNOv8Do7ErONVO1T1t5W2R8S3S6zbLWkG8ADJhIgLImK9pOnp9nkk94OcB2wEdgKXV6qb7noBsEDSOmAXcJkfMGVm1nGyXFV1MkkPAeDPgcfZ+wT2PiJiKUlyKFw3r+B9kMyDVVXddP0u4NIq4zYzs3aW5UFOoyLiPwEkXQ/cExF/k1dgZrkoN8w3eFLHxmFWx6qdcuQYkmGhVrtI7rUwM7Nuptoex50kVzD9hOSy2L8A7sgtKjMzq1nVXlX1dUn3A6enqy6PiCfyC8vMzGpVtUNVAIcBv4uIm4CW9FJZMzPrZqq9HPcrJFdWDQb+EegJ/D+SpwKa1T3fl2FWvWrPcfwFMBJYBRARr0pqa8oRs7pxVo+VZbY4cZgVq3aoald6z0UASPqD/EIyM7NaVm3i+KGkW4Hekj4DPAzcll9YZmZWq9ocqpIk4AfAEOB3JOc5vhwRD+Ucm5mZ1aA2E0dEhKSfRsRowMnCzKybq3aoanm5ByaZmVn3Uu1VVR8FpkvaBPwXyYOWIiKG5xWYmZnVpoqJQ9IxEfEy4BngzMwMaLvH8VOSWXFfkvRPEfGXHRCTmZnVsLbOcRQ++/vYPAMxM7P60FbiiDLvzcysm2prqOokSb8j6Xkcmr6H358c/0Cu0Zm1sysWrch1/57zyrqDiokjIho6KhAzM6sPWaZVNzMzq/o+DqsheQ+3WIFyzyg368bc4zAzs0ycOMzMLBMPVZnth3IPfnrk/dEdHIlZx3OPw8zMMnHiMDOzTHJNHJImStogaaOkmSW2S9KcdPtaSaMy1P07SSGpT55tMDOzveWWOCQ1AHNJZtYdCkyVNLSo2CSgKX1NA26ppq6kgcA5wMt5xW9mZqXl2eM4BdgYES9ExC5gMTC5qMxk4I5ILCd5pnm/Kur+A/C/8PxZZmYdLs+rqvoDmwuWW4CxVZTpX6mupAuBVyJiTfI49NIkTSPpxXDMMcfsXwvMMip3tRV4rirrOvLscZT6q17cQyhXpuR6SYcBXwS+3NaHR8T8iBgTEWP69u3bZrBmZladPBNHCzCwYHkA8GqVZcqt/zAwCFiTPsZ2ALBK0h+1a+RmZlZWnoljBdAkaZCkg4EpwJKiMkuAT6VXV40DtkfElnJ1I+LJiPhQRDRGRCNJghkVEb/JsR1mZlYgt3McEbFb0gzgAaABWBAR6yVNT7fPA5YC5wEbgZ3A5ZXq5hWrmZlVL9cpRyJiKUlyKFw3r+B9ANdUW7dEmcYDj9LMzLLwXFVmFXgKe7N9ecoRMzPLxInDzMwy8VCVWWcq94TBwZM6Ng6zDNzjMDOzTJw4zMwsEycOMzPLxInDzMwyceIwM7NMnDjMzCwTJw4zM8vE93HUsnLX+FuXUW5Kk9tvbL/7OBpn/qzk+k2z/HAp2z/ucZiZWSZOHGZmlomHqsw6QLnhorPKfHUrVx48xGSdzz0OMzPLxInDzMwyceIwM7NMnDjMzCwTJw4zM8vEV1XVMD/vuus4q8fKdivfOLP0el9tZR3FPQ4zM8vEicPMzDLxUFUt8JxUZlZH3OMwM7NMnDjMzCyTXBOHpImSNkjaKGmfa0GUmJNuXytpVFt1JX1L0jNp+Z9I6p1nG8zMbG+5JQ5JDcBcYBIwFJgqaWhRsUlAU/qaBtxSRd2HgBMjYjjwLPCFvNpgZmb7yrPHcQqwMSJeiIhdwGJgclGZycAdkVgO9JbUr1LdiHgwInan9ZcDA3Jsg5mZFckzcfQHNhcst6TrqilTTV2AvwZKXpIkaZqkZknNW7duzRi6mZmVk2fiUIl1UWWZNutK+iKwG7ir1IdHxPyIGBMRY/r27VtF
uGZmVo087+NoAQYWLA8AXq2yzMGV6kq6DLgAOCsiipORmZnlKM/EsQJokjQIeAWYAvxVUZklwAxJi4GxwPaI2CJpa7m6kiYCnwfOjIidOcZv1rVlvfF08KRMxcs9xdBzatW/3BJHROyWNAN4AGgAFkTEeknT0+3zgKXAecBGYCdweaW66a5vBg4BHpIEsDwipufVDjMz21uuU45ExFKS5FC4bl7B+wCuqbZuuv4j7RymmZll4LmqzLq48lO0n9yhcVjX4SlHzMwsEycOMzPLxInDzMwyceIwM7NMfHK8BvjZ4mZWT5w4zLqKGnuSZPmruXwDYL3zUJWZmWXiHkdHqbFvg1a/yn2Tv2JRtv2UGyK9/bKM93f4/3a34x6HmZll4sRhZmaZeKjKzKpyxXVfK7k+89CW1T33OMzMLBMnDjMzy8RDVR3EN/lZvcj6f7W9/m+Xe/BTOZsur/C91w+dypV7HGZmlokTh5mZZeKhKjOrCeVubHzk/dEl11caIrv9sjIbMg5hWWnucZiZWSZOHGZmlomHqsysQ5W7gumsdvwaW24Y65H332+/D+nG3OMwM7NMnDjMzCwTD1W1s3Lz+ZhZovwDntqnfHsqe2NguZsPy1y1lfcNhpVunszjJkb3OMzMLBMnDjMzyyTXoSpJE4GbgAbgexExq2i70u3nATuBT0fEqkp1JR0F/ABoBDYBn4iIN/JqQ9b5c9rzyhAz6xhZbz5sr/2Xe/569r87lYbz6mioSlIDMBeYBAwFpkoaWlRsEtCUvqYBt1RRdybwSEQ0AY+ky2Zm1kHy/H58CrAxIl6IiF3AYmByUZnJwB2RWA70ltSvjbqTgdanKy8CPpZjG8zMrIgiIp8dSxcBEyPib9Ll/w6MjYgZBWXuA2ZFxC/T5UeAz5MMQ5WsK+nNiOhdsI83IuLIEp8/jaQXAzAY2LCfTekDvL6fdWuN21J7uko7wG2pVQfSlj+JiL7FK/M8x6ES64qzVLky1dStKCLmA/Oz1ClFUnNEjDnQ/dQCt6X2dJV2gNtSq/JoS55DVS3AwILlAcCrVZapVPe36XAW6c/X2jFmMzNrQ56JYwXQJGmQpIOBKcCSojJLgE8pMQ7YHhFb2qi7BGidNPky4J9zbIOZmRXJbagqInZLmgE8QHJJ7YKIWC9perp9HrCU5FLcjSSX415eqW6661nADyVdAbwMXJxXG1IHPNxVQ9yW2tNV2gFuS61q97bkdnLczMy6Jt+uZmZmmThxmJlZJk4cFUiaKGmDpI2S6uoOdUmbJD0pabWk5nTdUZIekvRc+nOf+19qgaQFkl6TtK5gXdnYJX0hPUYbJP23zom6tDJtuV7SK+mxWS3pvIJtNdkWSQMl/ULS05LWS/qf6fq6Oy4V2lKPx6WXpF9LWpO25avp+nyPS0T4VeJFclL+eeBY4GBgDTC0s+PKEP8moE/Rum8CM9P3M4FvdHacZWI/AxgFrGsrdpIpadYAhwCD0mPW0NltaKMt1wN/V6JszbYF6AeMSt8fATybxlt3x6VCW+rxuAg4PH3fE/h3YFzex8U9jvKqmTKl3tTFdC0R8TjwH0Wry8U+GVgcEe9ExIskV+id0hFxVqNMW8qp2bZExJZIJyCNiP8Engb6U4fHpUJbyqnltkRE7EgXe6avIOfj4sRRXn9gc8FyC5X/c9WaAB6UtDKdfgXgDyO5T4b054c6LbrsysVer8dphqS16VBW6zBCXbRFUiMwkuTbbV0fl6K2QB0eF0kNklaT3Az9UETkflycOMo74GlPOtlpETGKZIbhaySd0dkB5aQej9MtwIeBEcAW4O/T9TXfFkmHA/8EfDYiflepaIl1td6WujwuEfFeRIwgmWHjFEknVijeLm1x4iivmilTalZEvJr+fA34CUl3tJ6naykXe90dp4j4bfrL/j5wG78fKqjptkjqSfKH9q6I+HG6ui6PS6m21OtxaRURbwKPAhPJ+bg4cZRXzZQpNUnSH0g6ovU9cC6w
jvqerqVc7EuAKZIOkTSI5Nkuv+6E+KrW+gud+guSYwM13BZJAm4Hno6IbxdsqrvjUq4tdXpc+krqnb4/FDgbeIa8j0tnXxVQyy+S6VCeJbny4IudHU+GuI8luXJiDbC+NXbgaJKHXz2X/jyqs2MtE//dJEMF75J8Q7qiUuzAF9NjtAGY1NnxV9GWO4EngbXpL3K/Wm8LMJ5kSGMtsDp9nVePx6VCW+rxuAwHnkhjXgd8OV2f63HxlCNmZpaJh6rMzCwTJw4zM8vEicPMzDJx4jAzs0ycOMzMLBMnDuuyJH0xnTF0bTrb6djOjulASFoo6aIc9ntdwfvGwpl8zUpx4rAuSdKpwAUks6AOJ7kxanPlWt3WdW0XMfs9Jw7rqvoBr0fEOwAR8Xqk07BIGi3psXQCyAcKpmYYnT7XYJmkb7V+85b0aUk3t+5Y0n2SJqTvz03Lr5J0Tzr/UevzUL6arn9S0pB0/eGS/jFdt1bSX1baTzkV2vCopG+kz2h4VtLp6frDJP0w/cwfSPp3SWMkzQIOTXtkd6W7b5B0W9pbezC9I9lsDycO66oeBAamfzy/K+lM2DNH0XeAiyJiNLAA+Hpa5x+B/xERp1bzAZL6AF8Czo5kQslm4G8Liryerr8F+Lt03f8GtkfEsLQn9PMq9lP8uZXaAHBQRJwCfBb4SrruauCN9DP/DzAaICJmAm9FxIiI+GRatgmYGxEnAG8Cf1nNv4d1Hwd1dgBmeYiIHZJGA6cDHwV+oOQpjs3AicBDyZRFNABbJH0Q6B0Rj6W7uJNkZuFKxpE8GOdX6b4OBpYVbG+dCHAl8PH0/dkk8561xvmGpAva2E+xwaXaUOZzG9P344Gb0s9cJ2lthf2/GBGrS+zDDHDisC4sIt4jmS30UUlPkkz2thJYX9yrSCeKKzf/zm727p33aq1G8vyDqWXqvZP+fI/f/66pxOe0tZ9iokQbqvjcar1T8P49wENVthcPVVmXJGmwpKaCVSOAl0gmduubnjxHUk9JJ0QyJfV2SePT8p8sqLsJGCGph6SB/H667eXAaZI+ku7rMEnHtRHag8CMgjiP3I/9lGxDG5/7S+ATafmhwLCCbe+mw19mVXHisK7qcGCRpKfSYZmhwPWRPAb4IuAbktaQzIz6p2mdy4G5kpYBbxXs61fAiyQzp84GWh87uhX4NHB3+hnLgSFtxHUDcKSkdennfzTrftpoQznfJUk2a4HPk8ymuj3dNh9YW3By3Kwiz45rVoKSR4reFxGVnqZWNyQ1AD0j4m1JHyaZavu4NAmZZeJzHGbdw2HAL9IhKQFXOWnY/nKPw8zMMvE5DjMzy8SJw8zMMnHiMDOzTJw4zMwsEycOMzPL5P8Dl75NdF69ep8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Both histograms share bins and range so that they are directly comparable\n",
    "hist_kwargs = dict(bins=50, range=(0, 300), density=True)\n",
    "sampled_lengths = sampled_batch.mask.sum(-1).cpu().numpy()\n",
    "real_lengths = real_batch.mask.sum(-1).cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(sampled_lengths, label=\"Sampled\", **hist_kwargs)\n",
    "ax.hist(real_lengths, alpha=0.3, label=\"Real data\", **hist_kwargs)\n",
    "ax.set_xlabel(\"Sequence length\")\n",
    "# density=True normalizes each histogram to unit area, so the y-axis shows a density\n",
    "ax.set_ylabel(\"Density\")\n",
    "# The trailing semicolon keeps the Legend repr out of the cell output\n",
    "ax.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch] *",
   "language": "python",
   "name": "conda-env-torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
